import os

import matplotlib.pyplot as plt

def generate_and_save_image(model, epoch, test_input, save_path):
    """Plot a batch of model outputs as a grid of grayscale images and save it.

    The figure is written to ``<save_path>/image_epoch_<epoch>.png`` and is
    closed afterwards, so repeated calls do not accumulate open figures.

    Args:
        model: Callable applied to ``test_input``. Its result must support
            ``.detach().cpu().numpy()`` (e.g. a PyTorch tensor).
        epoch: Epoch number, formatted with ``%d`` into the file name.
        test_input: Input passed to ``model`` unchanged.
        save_path: Directory for the image; created if it does not exist.

    Note:
        Assumes the first axis of the model output is the batch axis and
        that each sample reduces to a 2-D image once size-1 axes are
        dropped (e.g. ``(N, 1, H, W)``) -- TODO confirm against the model.
        Pixel values are rescaled with ``(x + 1) / 2``, which presumably
        targets outputs in ``[-1, 1]``; verify against the model's final
        activation.
    """
    output = model(test_input).detach().cpu().numpy()
    # Use the array's own squeeze() so the function does not depend on a
    # module-level ``np`` name.
    predictions = output.squeeze()
    if output.shape[:1] == (1,):
        # squeeze() also removed the batch axis of a single-sample batch;
        # put it back so the loop iterates over images, not image rows.
        predictions = predictions[None, ...]

    count = predictions.shape[0]
    # Keep the historical 4x4 layout for up to 16 images and grow the
    # (square) grid only when more images have to fit.
    grid = 4
    while grid * grid < count:
        grid += 1

    # An empty save_path means "current directory"; makedirs('') would fail.
    if save_path:
        os.makedirs(save_path, exist_ok=True)

    fig = plt.figure(figsize=(grid, grid))
    try:
        for idx in range(count):
            ax = fig.add_subplot(grid, grid, idx + 1)
            ax.imshow((predictions[idx] + 1) / 2, cmap='gray')
            ax.axis('off')
        fig.savefig(os.path.join(save_path, 'image_epoch_%d.png' % epoch))
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)